{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import random\n",
    "import torch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 626.6648559570312\n",
      "1 625.931640625\n",
      "2 624.5895385742188\n",
      "3 628.6260375976562\n",
      "4 620.8497924804688\n",
      "5 678.1670532226562\n",
      "6 616.9553833007812\n",
      "7 584.82958984375\n",
      "8 611.2676391601562\n",
      "9 611.9619750976562\n",
      "10 584.9210205078125\n",
      "11 392.85406494140625\n",
      "12 350.58795166015625\n",
      "13 600.918701171875\n",
      "14 264.4318542480469\n",
      "15 222.20655822753906\n",
      "16 593.484375\n",
      "17 606.4907836914062\n",
      "18 538.1737670898438\n",
      "19 602.8992309570312\n",
      "20 504.5230407714844\n",
      "21 476.6241760253906\n",
      "22 439.3219299316406\n",
      "23 585.73388671875\n",
      "24 523.3035278320312\n",
      "25 111.19840240478516\n",
      "26 296.1112365722656\n",
      "27 116.33778381347656\n",
      "28 443.4542541503906\n",
      "29 214.4071044921875\n",
      "30 188.0015869140625\n",
      "31 470.4693298339844\n",
      "32 440.1273498535156\n",
      "33 400.4791259765625\n",
      "34 127.97659301757812\n",
      "35 121.3911361694336\n",
      "36 299.3867492675781\n",
      "37 190.69114685058594\n",
      "38 175.23495483398438\n",
      "39 155.50230407714844\n",
      "40 175.8963165283203\n",
      "41 129.99551391601562\n",
      "42 167.50523376464844\n",
      "43 243.75430297851562\n",
      "44 124.16504669189453\n",
      "45 102.92121124267578\n",
      "46 61.74596405029297\n",
      "47 114.27743530273438\n",
      "48 81.23641204833984\n",
      "49 96.09291076660156\n",
      "50 114.64246368408203\n",
      "51 106.58434295654297\n",
      "52 82.4592056274414\n",
      "53 56.97783279418945\n",
      "54 69.61764526367188\n",
      "55 95.17444610595703\n",
      "56 45.247047424316406\n",
      "57 40.36801528930664\n",
      "58 37.14039611816406\n",
      "59 27.57843017578125\n",
      "60 116.99701690673828\n",
      "61 53.940826416015625\n",
      "62 27.53561782836914\n",
      "63 198.87245178222656\n",
      "64 94.19752502441406\n",
      "65 158.45431518554688\n",
      "66 64.4766616821289\n",
      "67 90.5487289428711\n",
      "68 41.20033645629883\n",
      "69 120.66567993164062\n",
      "70 47.448299407958984\n",
      "71 41.87496566772461\n",
      "72 62.02822494506836\n",
      "73 40.8214111328125\n",
      "74 49.975345611572266\n",
      "75 39.3807258605957\n",
      "76 36.9306755065918\n",
      "77 34.88636779785156\n",
      "78 36.56793975830078\n",
      "79 22.985971450805664\n",
      "80 27.525239944458008\n",
      "81 40.438907623291016\n",
      "82 30.0289249420166\n",
      "83 29.858667373657227\n",
      "84 23.56079864501953\n",
      "85 16.799585342407227\n",
      "86 11.022204399108887\n",
      "87 42.2107048034668\n",
      "88 19.32655906677246\n",
      "89 29.57625961303711\n",
      "90 58.701839447021484\n",
      "91 15.094594955444336\n",
      "92 18.096790313720703\n",
      "93 64.97696685791016\n",
      "94 20.944290161132812\n",
      "95 13.000467300415039\n",
      "96 26.43996810913086\n",
      "97 25.50496482849121\n",
      "98 17.699304580688477\n",
      "99 92.36943817138672\n",
      "100 73.4350814819336\n",
      "101 47.88151550292969\n",
      "102 142.7108917236328\n",
      "103 75.86461639404297\n",
      "104 43.130489349365234\n",
      "105 48.32160568237305\n",
      "106 11.0353422164917\n",
      "107 32.72777557373047\n",
      "108 71.67752075195312\n",
      "109 67.05614471435547\n",
      "110 34.504764556884766\n",
      "111 55.64155578613281\n",
      "112 19.97456169128418\n",
      "113 17.78594207763672\n",
      "114 66.19388580322266\n",
      "115 7.607739448547363\n",
      "116 36.73714065551758\n",
      "117 13.75013542175293\n",
      "118 13.007487297058105\n",
      "119 16.35108184814453\n",
      "120 24.795297622680664\n",
      "121 17.1218318939209\n",
      "122 8.694226264953613\n",
      "123 8.923991203308105\n",
      "124 21.429460525512695\n",
      "125 22.194068908691406\n",
      "126 17.538957595825195\n",
      "127 8.380979537963867\n",
      "128 23.679349899291992\n",
      "129 6.983810901641846\n",
      "130 6.797555923461914\n",
      "131 9.6456298828125\n",
      "132 9.129494667053223\n",
      "133 29.37210464477539\n",
      "134 4.590899467468262\n",
      "135 10.769716262817383\n",
      "136 6.500424385070801\n",
      "137 19.015993118286133\n",
      "138 3.4738221168518066\n",
      "139 15.679688453674316\n",
      "140 3.8937432765960693\n",
      "141 4.80169153213501\n",
      "142 11.772103309631348\n",
      "143 12.399144172668457\n",
      "144 11.034842491149902\n",
      "145 8.597681999206543\n",
      "146 6.29863166809082\n",
      "147 5.823895454406738\n",
      "148 7.416306972503662\n",
      "149 6.400707721710205\n",
      "150 2.2959728240966797\n",
      "151 4.654292106628418\n",
      "152 4.175559043884277\n",
      "153 5.169882774353027\n",
      "154 4.983944416046143\n",
      "155 3.2195234298706055\n",
      "156 2.246274709701538\n",
      "157 5.284778594970703\n",
      "158 4.414332866668701\n",
      "159 3.4492599964141846\n",
      "160 3.2929089069366455\n",
      "161 5.710289001464844\n",
      "162 1.6668301820755005\n",
      "163 1.3257898092269897\n",
      "164 3.0299782752990723\n",
      "165 9.268481254577637\n",
      "166 3.8608782291412354\n",
      "167 7.817559242248535\n",
      "168 8.860544204711914\n",
      "169 6.129063606262207\n",
      "170 2.9766769409179688\n",
      "171 2.526994228363037\n",
      "172 2.5948057174682617\n",
      "173 7.280127048492432\n",
      "174 5.6946516036987305\n",
      "175 5.336357593536377\n",
      "176 2.4791877269744873\n",
      "177 8.884866714477539\n",
      "178 4.475106716156006\n",
      "179 2.032745361328125\n",
      "180 10.401857376098633\n",
      "181 6.207302570343018\n",
      "182 2.3788890838623047\n",
      "183 3.107234001159668\n",
      "184 9.780716896057129\n",
      "185 6.474930286407471\n",
      "186 4.198608875274658\n",
      "187 4.474043369293213\n",
      "188 1.8451951742172241\n",
      "189 15.098800659179688\n",
      "190 9.537196159362793\n",
      "191 1.5553072690963745\n",
      "192 7.708290100097656\n",
      "193 27.312746047973633\n",
      "194 2.5492427349090576\n",
      "195 0.8334215879440308\n",
      "196 1.100028395652771\n",
      "197 23.88583755493164\n",
      "198 25.015056610107422\n",
      "199 1.3956083059310913\n",
      "200 15.893254280090332\n",
      "201 7.119331359863281\n",
      "202 22.203588485717773\n",
      "203 2.967785596847534\n",
      "204 6.2991108894348145\n",
      "205 13.813591957092285\n",
      "206 9.181497573852539\n",
      "207 0.90043705701828\n",
      "208 2.5055899620056152\n",
      "209 19.99829864501953\n",
      "210 1.1408966779708862\n",
      "211 2.7003860473632812\n",
      "212 1.0808647871017456\n",
      "213 5.0784759521484375\n",
      "214 4.78713321685791\n",
      "215 2.790778160095215\n",
      "216 3.520975112915039\n",
      "217 2.437868118286133\n",
      "218 3.1232240200042725\n",
      "219 4.716874122619629\n",
      "220 1.5115134716033936\n",
      "221 2.4669246673583984\n",
      "222 1.0440527200698853\n",
      "223 0.7943932414054871\n",
      "224 4.898380756378174\n",
      "225 1.9194144010543823\n",
      "226 1.9517515897750854\n",
      "227 1.0708328485488892\n",
      "228 2.9846107959747314\n",
      "229 2.5636680126190186\n",
      "230 3.5700697898864746\n",
      "231 2.7283661365509033\n",
      "232 1.9998834133148193\n",
      "233 6.288666725158691\n",
      "234 0.8900728821754456\n",
      "235 1.493361473083496\n",
      "236 3.6614274978637695\n",
      "237 2.348182201385498\n",
      "238 1.5174760818481445\n",
      "239 1.2545387744903564\n",
      "240 1.538651466369629\n",
      "241 1.5283797979354858\n",
      "242 2.593367099761963\n",
      "243 2.068077564239502\n",
      "244 2.6573328971862793\n",
      "245 2.8059496879577637\n",
      "246 2.789250612258911\n",
      "247 1.8278647661209106\n",
      "248 1.1054285764694214\n",
      "249 0.6214069128036499\n",
      "250 3.558936834335327\n",
      "251 0.5170868039131165\n",
      "252 0.5695268511772156\n",
      "253 1.4080314636230469\n",
      "254 1.6687734127044678\n",
      "255 1.5966085195541382\n",
      "256 1.8609368801116943\n",
      "257 0.8758656978607178\n",
      "258 1.5409241914749146\n",
      "259 1.4443429708480835\n",
      "260 1.259528398513794\n",
      "261 1.316975474357605\n",
      "262 1.0206393003463745\n",
      "263 1.3830511569976807\n",
      "264 0.8312196731567383\n",
      "265 0.7832522988319397\n",
      "266 0.7001824975013733\n",
      "267 0.5171079039573669\n",
      "268 0.7608780860900879\n",
      "269 2.3321073055267334\n",
      "270 0.7492213249206543\n",
      "271 0.5015385746955872\n",
      "272 0.5116440057754517\n",
      "273 1.4890344142913818\n",
      "274 0.810674250125885\n",
      "275 1.4723087549209595\n",
      "276 0.7713836431503296\n",
      "277 1.0014387369155884\n",
      "278 1.9245483875274658\n",
      "279 0.548655092716217\n",
      "280 0.48080089688301086\n",
      "281 0.5816141366958618\n",
      "282 3.1772022247314453\n",
      "283 1.4099746942520142\n",
      "284 1.49962317943573\n",
      "285 1.8915150165557861\n",
      "286 1.605009913444519\n",
      "287 1.312137246131897\n",
      "288 0.2847844660282135\n",
      "289 1.8400331735610962\n",
      "290 1.8883248567581177\n",
      "291 0.2327224761247635\n",
      "292 0.9114663600921631\n",
      "293 0.6707336902618408\n",
      "294 1.7941280603408813\n",
      "295 0.3524354100227356\n",
      "296 0.3950207829475403\n",
      "297 1.4392263889312744\n",
      "298 1.188582420349121\n",
      "299 0.8198367357254028\n",
      "300 0.8503790497779846\n",
      "301 1.1756277084350586\n",
      "302 0.9005414247512817\n",
      "303 0.5761359930038452\n",
      "304 0.9184694290161133\n",
      "305 1.0627102851867676\n",
      "306 0.32823458313941956\n",
      "307 1.3464131355285645\n",
      "308 0.20820927619934082\n",
      "309 0.25619247555732727\n",
      "310 1.7477818727493286\n",
      "311 0.7360484004020691\n",
      "312 0.493656724691391\n",
      "313 0.6112005114555359\n",
      "314 0.10986802726984024\n",
      "315 2.5336356163024902\n",
      "316 0.7076131701469421\n",
      "317 1.0575735569000244\n",
      "318 0.3297809660434723\n",
      "319 1.2822163105010986\n",
      "320 0.884023904800415\n",
      "321 0.08669804781675339\n",
      "322 1.2139806747436523\n",
      "323 0.9833776950836182\n",
      "324 0.21292546391487122\n",
      "325 0.8761008381843567\n",
      "326 0.7470531463623047\n",
      "327 0.07221823930740356\n",
      "328 0.08091810345649719\n",
      "329 1.2009234428405762\n",
      "330 0.678113579750061\n",
      "331 0.7657151222229004\n",
      "332 0.08356302976608276\n",
      "333 0.9236225485801697\n",
      "334 0.08825569599866867\n",
      "335 0.6936907768249512\n",
      "336 0.07659637928009033\n",
      "337 0.7210618257522583\n",
      "338 0.1379740834236145\n",
      "339 0.8812279105186462\n",
      "340 0.38847389817237854\n",
      "341 0.37561148405075073\n",
      "342 0.6450921893119812\n",
      "343 0.36764416098594666\n",
      "344 1.142336368560791\n",
      "345 0.2606552541255951\n",
      "346 0.16170087456703186\n",
      "347 0.7262434363365173\n",
      "348 0.6713383793830872\n",
      "349 0.10940562933683395\n",
      "350 0.43759360909461975\n",
      "351 1.0836721658706665\n",
      "352 0.49013644456863403\n",
      "353 0.6192764043807983\n",
      "354 0.9565967917442322\n",
      "355 0.3843451738357544\n",
      "356 0.8927955031394958\n",
      "357 0.7907313704490662\n",
      "358 0.15067391097545624\n",
      "359 0.7562111616134644\n",
      "360 0.5760303735733032\n",
      "361 0.5333252549171448\n",
      "362 0.906378448009491\n",
      "363 0.6761205792427063\n",
      "364 0.6005216240882874\n",
      "365 0.7682053446769714\n",
      "366 0.1561211496591568\n",
      "367 1.1062763929367065\n",
      "368 0.5382807850837708\n",
      "369 0.646061360836029\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "370 0.7108477354049683\n",
      "371 1.3025782108306885\n",
      "372 0.5687431693077087\n",
      "373 0.12669356167316437\n",
      "374 1.4945311546325684\n",
      "375 0.944771409034729\n",
      "376 0.5871182680130005\n",
      "377 0.4202112853527069\n",
      "378 0.6937492489814758\n",
      "379 0.5822004079818726\n",
      "380 0.5510810017585754\n",
      "381 0.33912143111228943\n",
      "382 0.5979511141777039\n",
      "383 0.6752994060516357\n",
      "384 0.15446896851062775\n",
      "385 0.5838897228240967\n",
      "386 0.28018245100975037\n",
      "387 0.6356022953987122\n",
      "388 0.2259702831506729\n",
      "389 0.49729546904563904\n",
      "390 0.5634748339653015\n",
      "391 0.4511571526527405\n",
      "392 0.13618643581867218\n",
      "393 0.3625476658344269\n",
      "394 0.6296881437301636\n",
      "395 0.6162651181221008\n",
      "396 0.7156220078468323\n",
      "397 0.08067669719457626\n",
      "398 0.07497218996286392\n",
      "399 0.46947574615478516\n",
      "400 0.05823824927210808\n",
      "401 0.054812464863061905\n",
      "402 0.30487683415412903\n",
      "403 0.5188630223274231\n",
      "404 0.24011509120464325\n",
      "405 0.22320348024368286\n",
      "406 0.1118219867348671\n",
      "407 0.17018495500087738\n",
      "408 0.48007476329803467\n",
      "409 0.1521710455417633\n",
      "410 0.10891757160425186\n",
      "411 0.4311083257198334\n",
      "412 0.11156876385211945\n",
      "413 0.1375710666179657\n",
      "414 0.35488301515579224\n",
      "415 0.09440729022026062\n",
      "416 0.31100594997406006\n",
      "417 0.09110146015882492\n",
      "418 0.9457928538322449\n",
      "419 0.21889875829219818\n",
      "420 0.2322666049003601\n",
      "421 0.09544331580400467\n",
      "422 0.09422259777784348\n",
      "423 0.23210197687149048\n",
      "424 0.7685413360595703\n",
      "425 0.2159704715013504\n",
      "426 0.6627539396286011\n",
      "427 0.2437107264995575\n",
      "428 0.45603594183921814\n",
      "429 0.531288206577301\n",
      "430 0.2266312688589096\n",
      "431 0.45839086174964905\n",
      "432 0.4207141101360321\n",
      "433 0.09500540792942047\n",
      "434 0.24853743612766266\n",
      "435 0.5892273187637329\n",
      "436 0.2866324484348297\n",
      "437 0.25401031970977783\n",
      "438 0.25893452763557434\n",
      "439 0.2254890501499176\n",
      "440 0.7016360759735107\n",
      "441 0.17497187852859497\n",
      "442 0.289255827665329\n",
      "443 0.6998990774154663\n",
      "444 0.06676200777292252\n",
      "445 0.057566795498132706\n",
      "446 0.04954655468463898\n",
      "447 0.042065247893333435\n",
      "448 0.6602599024772644\n",
      "449 0.02503911405801773\n",
      "450 0.29209503531455994\n",
      "451 0.42001476883888245\n",
      "452 0.31218221783638\n",
      "453 0.37083300948143005\n",
      "454 0.5052850246429443\n",
      "455 0.2807874381542206\n",
      "456 0.5272316932678223\n",
      "457 0.3653412461280823\n",
      "458 0.041884735226631165\n",
      "459 0.038469165563583374\n",
      "460 0.3338800072669983\n",
      "461 0.03671689331531525\n",
      "462 0.03517509624361992\n",
      "463 0.03202994167804718\n",
      "464 0.7529798746109009\n",
      "465 0.1868058294057846\n",
      "466 0.048153381794691086\n",
      "467 0.060195211321115494\n",
      "468 0.4677262008190155\n",
      "469 0.027607697993516922\n",
      "470 0.018481099978089333\n",
      "471 0.7515113353729248\n",
      "472 0.09024880826473236\n",
      "473 0.15346695482730865\n",
      "474 0.7065528035163879\n",
      "475 0.09334680438041687\n",
      "476 0.03496542572975159\n",
      "477 0.02390982396900654\n",
      "478 0.054738420993089676\n",
      "479 0.08889191597700119\n",
      "480 0.4580038785934448\n",
      "481 0.8226954340934753\n",
      "482 0.2960473895072937\n",
      "483 0.647020161151886\n",
      "484 1.0846458673477173\n",
      "485 0.42106977105140686\n",
      "486 0.025976162403821945\n",
      "487 0.8289820551872253\n",
      "488 0.8667175769805908\n",
      "489 0.45930516719818115\n",
      "490 0.3914187252521515\n",
      "491 0.17959195375442505\n",
      "492 1.1600592136383057\n",
      "493 0.6354894638061523\n",
      "494 0.06596527993679047\n",
      "495 1.435815453529358\n",
      "496 0.3067678213119507\n",
      "497 0.27009594440460205\n",
      "498 0.37564435601234436\n",
      "499 0.47951456904411316\n"
     ]
    }
   ],
   "source": [
    "class DynamicNet(torch.nn.Module):\n",
    "    \"\"\"Fully connected ReLU network whose depth is re-sampled on every forward pass.\n",
    "\n",
    "    The single `middle_fc` layer is applied a random number of times per call,\n",
    "    so its weights are shared across all of those applications.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, D_in, H, D_out, max_middle_repeats=3):\n",
    "        \"\"\"\n",
    "        Args:\n",
    "            D_in: number of input features.\n",
    "            H: hidden-layer width.\n",
    "            D_out: number of output features.\n",
    "            max_middle_repeats: inclusive upper bound on how many times the\n",
    "                shared middle layer is applied in one forward pass.\n",
    "        \"\"\"\n",
    "        super().__init__()\n",
    "        self.input_fc = torch.nn.Linear(D_in, H)\n",
    "        self.middle_fc = torch.nn.Linear(H, H)\n",
    "        self.output_fc = torch.nn.Linear(H, D_out)\n",
    "        self.max_middle_repeats = max_middle_repeats\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Map x (last dimension D_in) to predictions (last dimension D_out).\"\"\"\n",
    "        h_relu = self.input_fc(x).clamp(min=0)\n",
    "        # The repeat count is drawn on every call, so depth varies between passes.\n",
    "        for _ in range(random.randint(0, self.max_middle_repeats)):\n",
    "            h_relu = self.middle_fc(h_relu).clamp(min=0)\n",
    "        return self.output_fc(h_relu)\n",
    "\n",
    "\n",
    "# Seed both RNGs: `random` picks the depth, `torch` draws the data and initial weights.\n",
    "SEED = 0\n",
    "random.seed(SEED)\n",
    "torch.manual_seed(SEED)\n",
    "\n",
    "N, D_in, H, D_out = 64, 1000, 100, 10  # samples, input dim, hidden dim, output dim\n",
    "x = torch.randn(N, D_in)\n",
    "y = torch.randn(N, D_out)\n",
    "\n",
    "model = DynamicNet(D_in, H, D_out)\n",
    "# reduction='sum' replaces the deprecated size_average=False (identical summed loss).\n",
    "criterion = torch.nn.MSELoss(reduction='sum')\n",
    "optimizer = torch.optim.SGD(model.parameters(), lr=1e-4, momentum=0.9)\n",
    "\n",
    "for t in range(500):\n",
    "    y_pred = model(x)\n",
    "    loss = criterion(y_pred, y)\n",
    "    print(t, loss.item())\n",
    "\n",
    "    optimizer.zero_grad()\n",
    "    loss.backward()\n",
    "    optimizer.step()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
